import os
import random

import torch
import numpy as np
import chromadb

from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
from langchain.retrievers.merger_retriever import MergerRetriever
from langchain.text_splitter import SpacyTextSplitter, RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from FlagEmbedding import FlagLLMReranker, LayerWiseFlagLLMReranker, FlagReranker


def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    The same ``seed`` is handed to ``random``, ``numpy.random``, the PyTorch
    default generator and every CUDA generator. The
    ``torch.backends.cudnn.deterministic`` switch is not modified here.

    Args:
        seed: Integer seed forwarded unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


class RAGSearch:
    """Retrieve text chunks with up to three embedding models, then rerank them.

    On construction the class loads ``*.txt`` files from a knowledge-base
    directory, splits them into chunks, builds (or reopens) one Chroma vector
    store per configured embedding model, merges the per-model retrievers with
    a ``MergerRetriever`` and instantiates a FlagEmbedding reranker.
    ``search_item`` then runs a query through that pipeline.
    """

    def __init__(self, configuration):
        """Read the settings from ``configuration`` and build the pipeline.

        Args:
            configuration: Mapping that must provide the keys
                ``encoding_model1``, ``encoding_model2``, ``encoding_model3``,
                ``reranker_model``, ``knowledge_base_dir``, ``device``,
                ``knowledge_base_vector_path``, ``is_reload_vector_db`` and
                ``each_model_top_size``. An encoding model set to ``None`` is
                skipped.

        Raises:
            KeyError: If one of the keys above is missing.
        """
        self.lotr = None
        self.reranker_model = None
        self.em1_retriever = None
        self.em2_retriever = None
        self.em3_retriever = None
        self.retriever_list = []
        setup_seed(42)

        self.knowledge_base_dir_texts = None

        self.configuration = configuration

        self.em1 = self.configuration['encoding_model1']
        self.em2 = self.configuration['encoding_model2']
        self.em3 = self.configuration['encoding_model3']
        self.reranker = self.configuration['reranker_model']

        self.knowledge_base_dir = self.configuration['knowledge_base_dir']
        self.use_device = self.configuration['device']
        self.knowledge_base_vector_path = self.configuration['knowledge_base_vector_path']
        self.knowledge_base_reload_flag = self.configuration['is_reload_vector_db']
        self.each_model_top_size = self.configuration['each_model_top_size']

        self.prepare_knowledge_database()
        self.generate_embeddings_and_vectorstores()
        self.create_merger_and_reranker()
        print(f"Encoding Model: {self.em1}, {self.em2}, {self.em3}")
        print("RAG Search initialized")

    def prepare_knowledge_database(self):
        """Load the ``*.txt`` files of the knowledge base and split them into chunks.

        The resulting chunks (2048 characters with 512 characters of overlap)
        are stored in ``self.knowledge_base_dir_texts``.
        """
        loader = DirectoryLoader(
            self.knowledge_base_dir,
            glob="./*.txt",
            loader_cls=TextLoader,
            loader_kwargs={'autodetect_encoding': True},
        )
        documents = loader.load()
        splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=512)
        self.knowledge_base_dir_texts = splitter.split_documents(documents)

        print("Prepare Knowledge Database ......")

    def generate_embeddings_and_vectorstores(self):
        """Create one MMR retriever per configured embedding model.

        For every model slot (``em1``, ``em2``, ``em3``) whose model name is
        not ``None`` a Chroma store located at
        ``<knowledge_base_vector_path>/<slot>`` is either reopened or built
        from ``self.knowledge_base_dir_texts``. The retriever is stored in
        ``self.<slot>_retriever`` and appended to ``self.retriever_list``.
        """
        client_settings = chromadb.config.Settings(
            is_persistent=True,
            persist_directory=self.knowledge_base_vector_path,
            anonymized_telemetry=False,
        )

        # Existing stores are reopened when the reload flag is truthy and the
        # vector directory exists; otherwise the stores are built from the chunks.
        # NOTE(review): the previous comment on this condition said the stores
        # are loaded when reload is False, which is the opposite of what the
        # condition does. The condition is kept as-is - confirm the intended
        # meaning of `is_reload_vector_db`.
        load_existing = bool(self.knowledge_base_reload_flag) and os.path.exists(
            self.knowledge_base_vector_path)

        # (slot name, model name, embedding wrapper, collection metadata)
        # NOTE(review): Chroma's documented key for choosing the distance
        # metric is "hnsw:space" (values "l2", "ip", "cosine"); the "hnsw" key
        # and the "plm" value below are kept unchanged and presumably do not
        # select a metric - confirm before changing, existing collections may
        # depend on the current settings.
        model_slots = (
            ("em1", self.em1, HuggingFaceEmbeddings, {"hnsw": "cosine"}),
            ("em2", self.em2, HuggingFaceBgeEmbeddings, {"hnsw": "cosine"}),
            ("em3", self.em3, HuggingFaceEmbeddings, {"hnsw": "plm"}),
        )

        for slot, model_name, embeddings_cls, collection_metadata in model_slots:
            if model_name is None:
                # No model configured for this slot: nothing to embed or query with.
                continue

            embeddings = embeddings_cls(model_name=model_name,
                                        model_kwargs={"device": self.use_device},
                                        encode_kwargs={'normalize_embeddings': True})
            store_dir = os.path.join(self.knowledge_base_vector_path, slot)

            if load_existing:
                # The embedding function and the collection name must be given
                # when reopening: MMR search needs an embedding function to
                # embed the query, and the chunks were written to the
                # collection named after the slot, not the default collection.
                vectorstore = Chroma(collection_name=slot,
                                     embedding_function=embeddings,
                                     persist_directory=store_dir,
                                     client_settings=client_settings)
            else:
                vectorstore = Chroma.from_documents(self.knowledge_base_dir_texts, embeddings,
                                                    client_settings=client_settings,
                                                    collection_name=slot,
                                                    collection_metadata=collection_metadata,
                                                    persist_directory=store_dir)

            retriever = vectorstore.as_retriever(search_type="mmr",
                                                 search_kwargs={"k": self.each_model_top_size,
                                                                "include_metadata": True})
            setattr(self, f"{slot}_retriever", retriever)
            self.retriever_list.append(retriever)

    def create_merger_and_reranker(self):
        """Merge the per-model retrievers and instantiate the reranker.

        The reranker class is picked from substrings of ``self.reranker``:
        ``"m3"`` -> ``FlagReranker``, ``"gemma"`` -> ``FlagLLMReranker``,
        ``"minicpm"`` -> ``LayerWiseFlagLLMReranker``. For any other name (or
        ``None``) a hint is printed and ``self.reranker_model`` stays ``None``.
        """
        self.lotr = MergerRetriever(retrievers=self.retriever_list)

        # A missing reranker name is treated like an unrecognised one instead
        # of failing on the substring tests below.
        reranker_name = self.reranker or ""
        if "m3" in reranker_name:
            self.reranker_model = FlagReranker(self.reranker, use_fp16=True)
        elif 'gemma' in reranker_name:
            self.reranker_model = FlagLLMReranker(self.reranker, use_fp16=True)
        elif 'minicpm' in reranker_name:
            self.reranker_model = LayerWiseFlagLLMReranker(self.reranker, use_fp16=True)
        else:
            print("custom reranker model, please rewrite rag_search.py create_merger_and_reranker to support it.")

    def search_item(self, search_text, related_search_text="", topN=10):
        """Return the best-matching knowledge chunks for ``search_text``.

        The merged retriever is queried, duplicate chunk texts are dropped
        (first occurrence wins), the remaining chunks are scored against the
        query by the reranker and the highest-scoring ones are returned.

        Args:
            search_text: Term whose definition is searched for.
            related_search_text: Optional extra context appended to the query.
            topN: Maximum number of chunks to return.

        Returns:
            A list of at most ``topN`` distinct chunk texts ordered from
            highest to lowest reranker score. The list is empty when nothing
            is retrieved or ``topN`` is not positive. If no reranker model is
            available the chunks are returned in retrieval order.
        """
        query = f"what is the most related definition of {search_text}? {related_search_text}"

        # De-duplicate before reranking: the merged retrievers can return the
        # same chunk several times, and duplicates would otherwise take up
        # slots in the top-N selection and cost reranker time.
        candidates = list(dict.fromkeys(
            chunk.page_content for chunk in self.lotr.get_relevant_documents(query)))

        if not candidates or topN <= 0:
            return []

        if self.reranker_model is None:
            # Unsupported/custom reranker: fall back to the retrieval order.
            return candidates[:topN]

        scores = self.reranker_model.compute_score([[query, text] for text in candidates])
        # The reranker may hand back a bare scalar for a single pair; make it 1-D.
        scores = np.atleast_1d(np.asarray(scores)).ravel()

        # Stable descending sort: ties keep retrieval order, and slicing copes
        # with fewer than topN candidates.
        ranked_indices = np.argsort(-scores, kind="stable")[:topN]
        return [candidates[idx] for idx in ranked_indices]
